# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DPO Config."""
import warnings
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Literal, Optional

from mindnlp.engine import TrainingArguments


class FDivergenceType(Enum):
    """Supported f-divergence regularization types.

    This enum is the type of `DPOConfig.f_divergence_type`, which selects the
    f-divergence regularization function used to compute the divergence
    between the policy and the reference model.
    """
    # Default value of `DPOConfig.f_divergence_type`.
    REVERSE_KL = "reverse_kl"
    JS_DIVERGENCE = "js_divergence"
    # NOTE(review): presumably parameterized by `DPOConfig.f_alpha_divergence_coef`
    # -- confirm against the trainer/loss code.
    ALPHA_DIVERGENCE = "alpha_divergence"


class FDivergenceConstants:
    """Constants related to the α-divergence coefficient.

    NOTE(review): neither constant is referenced in the visible part of this
    file; they are presumably consumed by the trainer/loss code -- confirm
    against the callers.
    """
    # String key naming the α-divergence coefficient.
    ALPHA_DIVERGENCE_COEF_KEY = "alpha_divergence_coef"
    # Same value as the default of `DPOConfig.f_alpha_divergence_coef`.
    ALPHA_DIVERGENCE_COEF_DEFAULT = 1.0


@dataclass
class DPOConfig(TrainingArguments):
    r"""
    Training arguments for the [`DPOTrainer`].

    Extends `TrainingArguments` with the DPO-specific options listed below.
    Using [`~transformers.HfArgumentParser`], this class can be turned into
    [argparse] arguments that can be specified on the command line.

    Parameters:
        beta (`float`, *optional*, defaults to `0.1`):
            Parameter controlling the deviation from the reference model. Higher β
            means less deviation from the reference model. For the IPO loss
            (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
            the [paper](https://huggingface.co/papers/2310.12036).
        label_smoothing (`float`, *optional*, defaults to `0.0`):
            Robust DPO label smoothing parameter from the
            [cDPO](https://ericmitchell.ai/cdpo.pdf) report and the
            [Robust DPO](https://huggingface.co/papers/2403.00409) paper. It should
            be between `0.0` and `0.5`.
        loss_type (`str`, *optional*, defaults to `"sigmoid"`):
            Type of loss to use. Possible values are:

                - `"sigmoid"`: sigmoid loss from the original
                  [DPO](https://huggingface.co/papers/2305.18290) paper.
                - `"hinge"`: hinge loss on the normalized likelihood from the
                  [SLiC](https://huggingface.co/papers/2305.10425) paper.
                - `"ipo"`: IPO loss from the
                  [IPO](https://huggingface.co/papers/2310.12036) paper.
                - `"exo_pair"`: pairwise EXO loss from the
                  [EXO](https://huggingface.co/papers/2402.00856) paper.
                - `"nca_pair"`: pairwise NCA loss from the
                  [NCA](https://huggingface.co/papers/2402.05369) paper.
                - `"robust"`: unbiased estimate of the DPO loss that is robust to
                  preference noise, from the
                  [Robust DPO](https://huggingface.co/papers/2403.00409) paper.
                - `"bco_pair"`: pairwise BCO loss from the
                  [BCO](https://huggingface.co/papers/2404.04656) paper.
                - `"sppo_hard"`: SPPO loss with hard label from the
                  [SPPO](https://huggingface.co/papers/2405.00675) paper.
                - `"aot"`: AOT loss for paired datasets from the
                  [AOT](https://huggingface.co/papers/2406.05882) paper.
                - `"aot_pair"`: AOT loss for unpaired datasets from the
                  [AOT](https://huggingface.co/papers/2406.05882) paper.
                - `"apo_zero"`: APO-zero loss from the
                  [APO](https://huggingface.co/papers/2408.06266) paper.
                - `"apo_down"`: APO-down loss from the
                  [APO](https://huggingface.co/papers/2408.06266) paper.

        label_pad_token_id (`int`, *optional*, defaults to `-100`):
            Label pad token id. This argument is required if you want to use the
            default data collator.
        padding_value (`Optional[int]`, *optional*, defaults to `None`):
            Padding value to use. If `None`, the padding value of the tokenizer is
            used.
        truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
            Truncation mode to use when the prompt is too long, either
            `"keep_end"` or `"keep_start"`. This argument is required if you want
            to use the default data collator.
        max_length (`Optional[int]`, *optional*, defaults to `None`):
            Maximum length of the sequences (prompt + completion) in the batch.
            This argument is required if you want to use the default data
            collator.
        max_prompt_length (`Optional[int]`, *optional*, defaults to `None`):
            Maximum length of the prompt. This argument is required if you want to
            use the default data collator.
        max_target_length (`Optional[int]`, *optional*, defaults to `None`):
            Deprecated in favor of `max_completion_length`. When set, a
            `FutureWarning` is emitted and the value is copied into
            `max_completion_length` if that argument is `None`.
        max_completion_length (`Optional[int]`, *optional*, defaults to `None`):
            Maximum length of the target. This argument is required if you want to
            use the default data collator and your model is an encoder-decoder.
        is_encoder_decoder (`Optional[bool]`, *optional*, defaults to `None`):
            When using the `model_init` argument (callable) to instantiate the
            model instead of the `model` argument, you need to specify if the
            model returned by the callable is an encoder-decoder model.
        disable_dropout (`bool`, *optional*, defaults to `True`):
            Whether to disable dropout in the model and reference model.
        generate_during_eval (`bool`, *optional*, defaults to `False`):
            Whether to generate completions during evaluation.
            NOTE(review): how the generations are produced and logged is decided
            by the trainer, not by this config -- confirm against `DPOTrainer`.
        precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
            Whether to precompute reference model log probabilities for training
            and evaluation datasets. This is useful when training without the
            reference model to reduce the total GPU memory needed.
        dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
            Number of processes to use for processing the dataset.
        model_init_kwargs (`Optional[Dict[str, Any]]`, *optional*, defaults to `None`):
            Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained`
            when instantiating the model from a string.
        ref_model_init_kwargs (`Optional[Dict[str, Any]]`, *optional*, defaults to `None`):
            Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained`
            when instantiating the reference model from a string.
        model_adapter_name (`Optional[str]`, *optional*, defaults to `None`):
            Name of the train target PEFT adapter, when using LoRA with multiple
            adapters.
        ref_adapter_name (`Optional[str]`, *optional*, defaults to `None`):
            Name of the reference PEFT adapter, when using LoRA with multiple
            adapters.
        reference_free (`bool`, *optional*, defaults to `False`):
            If `True`, we ignore the _provided_ reference model and implicitly use
            a reference model that assigns equal probability to all responses.
        force_use_ref_model (`bool`, *optional*, defaults to `False`):
            In case one passes a PEFT model for the active model and you want to
            use a different model for the ref_model, set this flag to `True`.
        f_divergence_type (`FDivergenceType`, *optional*, defaults to `FDivergenceType.REVERSE_KL`):
            Type of f-divergence regularization function to compute divergence
            between policy and reference model.
        f_alpha_divergence_coef (`float`, *optional*, defaults to `1.0`):
            α coefficient in the α-divergence \\(u^{-\\alpha}\\) regularization
            function for DPO loss.
        sync_ref_model (`bool`, *optional*, defaults to `False`):
            When set to `True`, the reference model is synchronized with the
            active model every `ref_model_sync_steps` steps, using the
            `ref_model_mixup_alpha` parameter. This synchronization originates
            from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper.
        ref_model_mixup_alpha (`float`, *optional*, defaults to `0.9`):
            α parameter from the
            [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which
            controls the mix between the current policy and the previous
            reference policy during updates. The reference policy is updated
            according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`.
            To use this parameter, you must set `sync_ref_model=True`.
        ref_model_sync_steps (`int`, *optional*, defaults to `64`):
            τ parameter from the
            [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which
            determines how frequently the current policy is synchronized with the
            reference policy. To use this parameter, you must set
            `sync_ref_model=True`.
        rpo_alpha (`Optional[float]`, *optional*, defaults to `None`):
            α parameter from the [RPO](https://huggingface.co/papers/2404.19733)
            paper (v3), which controls the weighting of the NLL term in the loss.
            If `None`, no weighting is applied and the loss is the same as the DPO
            loss. The paper recommends `rpo_alpha=1.0`.
    """

    # NOTE: field order is part of the dataclass interface (it defines the
    # generated `__init__` signature), so the declarations keep their order.
    beta: float = 0.1
    label_smoothing: float = 0.0
    loss_type: Literal[
        "sigmoid",
        "hinge",
        "ipo",
        "exo_pair",
        "nca_pair",
        "robust",
        "bco_pair",
        "sppo_hard",
        "aot",
        "aot_pair",
        "apo_zero",
        "apo_down",
    ] = "sigmoid"
    label_pad_token_id: int = -100
    padding_value: Optional[int] = None
    truncation_mode: str = "keep_end"
    max_length: Optional[int] = None
    max_prompt_length: Optional[int] = None
    max_target_length: Optional[int] = None  # deprecated in favor of max_completion_length
    max_completion_length: Optional[int] = None
    is_encoder_decoder: Optional[bool] = None
    disable_dropout: bool = True
    generate_during_eval: bool = False
    precompute_ref_log_probs: bool = False
    dataset_num_proc: Optional[int] = None
    model_init_kwargs: Optional[Dict[str, Any]] = None
    ref_model_init_kwargs: Optional[Dict[str, Any]] = None
    model_adapter_name: Optional[str] = None
    ref_adapter_name: Optional[str] = None
    reference_free: bool = False
    force_use_ref_model: bool = False
    f_divergence_type: FDivergenceType = FDivergenceType.REVERSE_KL
    f_alpha_divergence_coef: float = 1.0
    sync_ref_model: bool = False
    ref_model_mixup_alpha: float = 0.9
    ref_model_sync_steps: int = 64
    rpo_alpha: Optional[float] = None

    def __post_init__(self):
        """Handle the deprecated `max_target_length`, then run the parent hook.

        If `max_target_length` was supplied, a `FutureWarning` is emitted and
        its value is used for `max_completion_length` unless the latter was
        set explicitly.

        Returns:
            Whatever the parent class's `__post_init__` returns.
        """
        legacy_length = self.max_target_length
        if legacy_length is not None:
            deprecation_message = (
                "The `max_target_length` argument is deprecated in favor of "
                "`max_completion_length` and will be removed in a future version."
            )
            warnings.warn(deprecation_message, FutureWarning)
            # An explicitly provided `max_completion_length` takes precedence
            # over the deprecated argument.
            if self.max_completion_length is None:
                self.max_completion_length = legacy_length

        return super().__post_init__()
